#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the BSD 3-Clause License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://spdx.org/licenses/BSD-3-Clause.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions that handle saving and loading of checkpoints."""

import os
from shutil import copyfileobj

import pycls.core.distributed as dist
import torch
from iopath.common.file_io import g_pathmgr
from pycls.core.config import cfg
from pycls.core.net import unwrap_model


# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"

# Checkpoints directory name
_DIR_NAME = "checkpoints"


def get_checkpoint_dir():
    """Retrieves the location for storing checkpoints."""
    return os.path.join(cfg.OUT_DIR, _DIR_NAME)


def get_checkpoint(epoch):
    """Retrieves the path to a checkpoint file."""
    name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
    return os.path.join(get_checkpoint_dir(), name)


def get_checkpoint_best():
    """Retrieves the path to the best checkpoint file."""
    return os.path.join(cfg.OUT_DIR, "model.pyth")
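
# A minimal sketch of the resulting layout (the OUT_DIR value below is an
# illustrative assumption, not a default):
#
#   cfg.OUT_DIR = "/tmp/exp"
#   get_checkpoint_dir()  -> "/tmp/exp/checkpoints"
#   get_checkpoint(10)    -> "/tmp/exp/checkpoints/model_epoch_0010.pyth"
#   get_checkpoint_best() -> "/tmp/exp/model.pyth"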


def get_last_checkpoint():
    """Retrieves the most recent checkpoint (highest epoch number)."""
    checkpoint_dir = get_checkpoint_dir()
    checkpoints = [f for f in g_pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f]
    assert checkpoints, "No checkpoints found in '{}'".format(checkpoint_dir)
    # Zero-padded epoch numbers make lexicographic order match epoch order
    last_checkpoint_name = sorted(checkpoints)[-1]
    return os.path.join(checkpoint_dir, last_checkpoint_name)


def has_checkpoint():
    """Determines if there are checkpoints available."""
    checkpoint_dir = get_checkpoint_dir()
    if not g_pathmgr.exists(checkpoint_dir):
        return False
    return any(_NAME_PREFIX in f for f in g_pathmgr.ls(checkpoint_dir))


def save_checkpoint(model, optimizer, epoch, best):
    """Saves a checkpoint; returns the path of the file written ("" if none)."""
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return
    # Ensure that the checkpoint dir exists
    g_pathmgr.mkdirs(get_checkpoint_dir())
    # Periodic checkpoints are written on the first epoch and every 20th epoch
    periodic = epoch % 20 == 0 or epoch == 1
    if not periodic and not best:
        return ""
    # Record the state
    checkpoint = {
        "epoch": epoch,
        "model_state": unwrap_model(model).state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    if periodic:
        # Write the periodic checkpoint (checkpoint files are 1-indexed by epoch)
        checkpoint_file = get_checkpoint(epoch + 1)
        with g_pathmgr.open(checkpoint_file, "wb") as f:
            torch.save(checkpoint, f)
        # If best, copy the periodic checkpoint to the best checkpoint
        if best:
            with g_pathmgr.open(checkpoint_file, "rb") as src:
                with g_pathmgr.open(get_checkpoint_best(), "wb") as dst:
                    copyfileobj(src, dst)
    else:
        # Best but off the periodic schedule: write directly to the best checkpoint
        checkpoint_file = get_checkpoint_best()
        with g_pathmgr.open(checkpoint_file, "wb") as f:
            torch.save(checkpoint, f)
    return checkpoint_file
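
# A minimal training-loop sketch of how save_checkpoint is meant to be called
# (`train_epoch` and `eval_err` are hypothetical stand-ins; only the tracking
# of `best` is shown):
#
#   best_err = float("inf")
#   for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):
#       train_epoch(model, optimizer)
#       err = eval_err(model)
#       best = err < best_err
#       best_err = min(err, best_err)
#       save_checkpoint(model, optimizer, cur_epoch, best)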


def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Loads the checkpoint from the given file."""
    err_str = "Checkpoint '{}' not found"
    assert g_pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
    with g_pathmgr.open(checkpoint_file, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    unwrap_model(model).load_state_dict(checkpoint["model_state"])
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]


def pretrained_load_checkpoint(checkpoint_file, model, optimizer=None):
    """Loads the checkpoint from the given file."""
    err_str = "Checkpoint '{}' not found"
    assert g_pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
    with g_pathmgr.open(checkpoint_file, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
        checkpoint["model_state"].pop("head.fc.weight")
        checkpoint["model_state"].pop("head.fc.bias")
    unwrap_model(model).load_state_dict(checkpoint["model_state"], strict=False)
    optimizer.load_state_dict(checkpoint["optimizer_state"]) if optimizer else ()
    return checkpoint["epoch"]


def delete_checkpoints(checkpoint_dir=None, keep="all"):
    """Deletes unneeded checkpoints, keep can be "all", "last", or "none"."""
    assert keep in ["all", "last", "none"], "Invalid keep setting: {}".format(keep)
    checkpoint_dir = checkpoint_dir if checkpoint_dir else get_checkpoint_dir()
    if keep == "all" or not g_pathmgr.exists(checkpoint_dir):
        return 0
    checkpoints = [f for f in g_pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f]
    checkpoints = sorted(checkpoints)[:-1] if keep == "last" else checkpoints
    for checkpoint in checkpoints:
        g_pathmgr.rm(os.path.join(checkpoint_dir, checkpoint))
    return len(checkpoints)
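
# Example: keep only the most recent periodic checkpoint and delete the rest;
# the return value is the number of files removed:
#
#   num_deleted = delete_checkpoints(keep="last")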
